// Copyright (c) OpenMMLab. All rights reserved.

#include "reference.h"

#include <stdexcept>
#include <string>

#include "src/turbomind/kernels/attention/rotary_embedding.h"
#include "src/turbomind/kernels/core/array_ops.h"
#include "src/turbomind/kernels/unfused_attention_kernels.h"

namespace turbomind {

// Fills a dense visibility mask with 1 where query `q` may attend key `k` and 0 elsewhere.
//
// Launch layout: blockIdx.x selects the batch entry (it indexes `q_lens` / `k_lens` and the
// `max_q_len * max_k_len` slice of `mask`); the threads of the block stride over that slice
// using threadIdx.x / blockDim.x.
//
// `q_lens` / `k_lens` may be null, in which case `max_q_len` / `max_k_len` are used as the
// actual lengths. A key is visible when both indices lie inside the actual lengths and the key
// is 0 .. window_size-1 positions behind the query's position `(k_len - q_len) + q`.
template<class T>
__global__ void
createCausalMasks(T* mask, const int* q_lens, const int* k_lens, int64_t max_q_len, int64_t max_k_len, int window_size)
{
    const int batch_idx = blockIdx.x;

    // Per-sequence lengths fall back to the padded maxima when no length array is given.
    int64_t q_len = max_q_len;
    if (q_lens) {
        q_len = q_lens[batch_idx];
    }
    int64_t k_len = max_k_len;
    if (k_lens) {
        k_len = k_lens[batch_idx];
    }

    // Number of keys that precede the first query token.
    const int history = k_len - q_len;

    T* const      batch_mask = mask + batch_idx * max_q_len * max_k_len;
    const int64_t elem_count = max_q_len * max_k_len;

    for (int64_t i = threadIdx.x; i < elem_count; i += blockDim.x) {
        const int q = i / max_k_len;  // row: query index
        const int k = i % max_k_len;  // column: key index

        // How far key `k` lies behind query `q` once the history offset is applied.
        const int dist = q - (k - history);

        bool visible = false;
        if (q < q_len && k < k_len) {
            visible = 0 <= dist && dist < window_size;
        }

        mask[batch_idx * max_q_len * max_k_len + i] = visible ? T{1.} : T{0.};
        (void)batch_mask;
    }
}

// Applies rotary embedding in place to a K cache laid out as [B, H, S, D].
//
// Launch layout: blockIdx.x = token position, blockIdx.y = head, blockIdx.z = batch entry.
// The threads of a block stride over `head_dim` in chunks of kVecSize (2) elements, so
// `head_dim` is expected to be a multiple of 2.
template<class T>
__global__ void
applyRotaryEmbedding(T* k_cache, int max_k_len, int head_num, int head_dim, float rope_base, int rope_dim)
{
    constexpr int kVecSize = 2;
    const int     history  = 0;

    const int    token = blockIdx.x;
    const size_t head  = blockIdx.y;
    const size_t batch = blockIdx.z;

    // Position fed to the rotary embedding; also selects the row inside the [S, D] plane.
    const int position = history + token;

    // Offset of element 0 of this (batch, head, position) row.
    const size_t row_offset =
        batch * head_num * max_k_len * head_dim + head * max_k_len * head_dim + position * head_dim;

    T* const row = k_cache + row_offset;

    for (int d = threadIdx.x * kVecSize; d < head_dim; d += blockDim.x * kVecSize) {
        T* const elem = row + d;

        Array<T, kVecSize> chunk;
        Load(chunk, elem);

        RotaryEmbedding<kVecSize> rope(rope_base, rope_dim, position, {d, 0});
        rope.apply(chunk);

        Store(elem, chunk);
    }
}

// Host launcher for applyRotaryEmbedding: rotates a [B, H, S, D] K cache in place on `stream`.
//
// The grid is (max_k_len, head_num, batch_size) with 128 threads per block. The launch is
// asynchronous; no synchronization is performed here.
//
// Empty problems (any extent <= 0) are a no-op. Launching with a zero grid extent would be an
// invalid launch configuration, and the resulting error would otherwise stay pending and be
// reported by some later, unrelated CUDA call.
template<class T>
void invokeApplyRotaryEmbedding(T*           k_cache,
                                int          max_k_len,
                                int          head_num,
                                int          head_dim,
                                float        rope_base,
                                int          rope_dim,
                                int          batch_size,
                                cudaStream_t stream)
{
    if (max_k_len <= 0 || head_num <= 0 || head_dim <= 0 || batch_size <= 0) {
        return;  // nothing to rotate
    }

    int  threads = 128;
    dim3 blocks(max_k_len, head_num, batch_size);

    applyRotaryEmbedding<<<blocks, threads, 0, stream>>>(k_cache, max_k_len, head_num, head_dim, rope_base, rope_dim);
}

// Explicit instantiations of the host launcher: half always, nv_bfloat16 only when the build
// defines ENABLE_BF16.
template void invokeApplyRotaryEmbedding<half>(half*, int, int, int, float, int, int, cudaStream_t);
#if ENABLE_BF16
template void invokeApplyRotaryEmbedding<nv_bfloat16>(nv_bfloat16*, int, int, int, float, int, int, cudaStream_t);
#endif

// Splits a packed QKV activation into a Q buffer and K/V caches, optionally adding a bias and
// applying rotary embedding to Q and K.
//
// Launch layout: blockIdx.x = query token, blockIdx.y = query head, blockIdx.z = batch entry.
// The threads of a block stride over `head_dim` in chunks of kVecSize (2) elements.
//
// Each token of `qkv` holds `head_num` Q heads followed by `kv_head_num` K heads and
// `kv_head_num` V heads; `qkv_bias` (may be null) uses the same [Q; K; V] ordering.
// New tokens are written to the caches at position `(max_k_len - max_q_len) + token`.
// Rotary embedding is skipped when `rope_theta` is zero.
template<class T>
__global__ void processQKV(T*       q_out,     // [B, H, s, D]
                           T*       k_cache,   // [B, H, S, D]
                           T*       v_cache,   // [B, H, S, D]
                           const T* qkv,       // [B, s, H, D]
                           const T* qkv_bias,  // [Q; K; V]
                           int      max_q_len,
                           int      max_k_len,
                           int      head_num,
                           int      head_dim,
                           int      kv_head_num,
                           float    rope_theta,
                           int      rope_dim)
{
    using namespace ops;

    constexpr int kVecSize = 2;

    const int    token = blockIdx.x;
    const size_t head  = blockIdx.y;
    const size_t batch = blockIdx.z;

    // Tokens already present in the cache before this step, and the resulting position.
    const int history  = max_k_len - max_q_len;
    const int position = history + token;

    const bool has_bias = qkv_bias != nullptr;
    const bool use_rope = rope_theta != 0.f;

    // Heads per token in the packed input: Q heads plus K and V heads.
    const size_t qkv_head_num = head_num + 2 * kv_head_num;

    const T* q_src = qkv + (batch * max_q_len + token) * qkv_head_num * head_dim;
    const T* k_src = q_src + head_num * head_dim;
    const T* v_src = k_src + kv_head_num * head_dim;

    // Bias rows for this head; the K/V rows are only dereferenced when head < kv_head_num.
    const T* q_bias = has_bias ? qkv_bias + head * head_dim : nullptr;
    const T* k_bias = has_bias ? q_bias + head_num * head_dim : nullptr;
    const T* v_bias = has_bias ? k_bias + kv_head_num * head_dim : nullptr;

    // ---- Q: [B, s, H, D] -> [B, H, s, D] ----
    {
        const T*     src = q_src + head * head_dim;
        const size_t dst = batch * head_num * max_q_len * head_dim + head * max_q_len * head_dim + token * head_dim;

        for (int d = threadIdx.x * kVecSize; d < head_dim; d += blockDim.x * kVecSize) {
            Array<T, kVecSize> vec_Q;
            Ldg(vec_Q, src + d);
            if (has_bias) {
                Array<T, kVecSize> bias_Q;
                Load(bias_Q, q_bias + d);
                vec_Q = vec_Q + bias_Q;
            }
            if (use_rope) {
                RotaryEmbedding<kVecSize> rope(rope_theta, rope_dim, position, {d, 0});
                rope.apply(vec_Q);
            }
            Store(q_out + dst + d, vec_Q);
        }
    }

    // Only the first kv_head_num head-blocks also carry a K/V head.
    if (head >= kv_head_num) {
        return;
    }

    // ---- K / V: append to the caches, which are indexed with kv_head_num heads ----
    const T*     k_head = k_src + head * head_dim;
    const T*     v_head = v_src + head * head_dim;
    const size_t kv_dst =
        batch * kv_head_num * max_k_len * head_dim + head * max_k_len * head_dim + position * head_dim;

    for (int d = threadIdx.x * kVecSize; d < head_dim; d += blockDim.x * kVecSize) {
        Array<T, kVecSize> vec_K;
        Array<T, kVecSize> vec_V;
        Ldg(vec_K, k_head + d);
        Ldg(vec_V, v_head + d);
        if (has_bias) {
            Array<T, kVecSize> bias_K;
            Array<T, kVecSize> bias_V;
            Load(bias_K, k_bias + d);
            Load(bias_V, v_bias + d);
            vec_K = vec_K + bias_K;
            vec_V = vec_V + bias_V;
        }
        if (use_rope) {
            // Only K is rotated; V is stored as is.
            RotaryEmbedding<kVecSize> rope(rope_theta, rope_dim, position, {d, 0});
            rope.apply(vec_K);
        }
        Store(k_cache + kv_dst + d, vec_K);
        Store(v_cache + kv_dst + d, vec_V);
    }
}

// Expands K/V caches that hold `kv_head_num` heads into buffers with `head_num` heads by
// copying KV head `h / n_reps` into output head `h`.
//
// Launch layout: blockIdx.x = token position, blockIdx.y = output head, blockIdx.z = batch
// entry; the threads of a block stride over `head_dim`.
// Source layout is [B, kv_head_num, max_k_len, head_dim], destination is
// [B, head_num, max_k_len, head_dim].
template<class T>
__global__ void RepeatKVKernel(T*       keys,
                               T*       vals,
                               const T* k_cache,
                               const T* v_cache,
                               int      head_num,
                               int      max_k_len,
                               int      head_dim,
                               int      kv_head_num,
                               int      n_reps)
{
    const int64_t token = blockIdx.x;
    const int64_t head  = blockIdx.y;
    const int64_t batch = blockIdx.z;

    // KV head shared by this output head.
    const int64_t kv_head = head / n_reps;

    // Offsets of element 0 of the destination / source rows (64-bit to avoid overflow).
    const int64_t dst_row = batch * head_num * max_k_len * head_dim + head * max_k_len * head_dim + token * head_dim;
    const int64_t src_row =
        batch * kv_head_num * max_k_len * head_dim + kv_head * max_k_len * head_dim + token * head_dim;

    T* const       k_dst = keys + dst_row;
    T* const       v_dst = vals + dst_row;
    const T* const k_src = k_cache + src_row;
    const T* const v_src = v_cache + src_row;

    for (int d = threadIdx.x; d < head_dim; d += blockDim.x) {
        k_dst[d] = k_src[d];
        v_dst[d] = v_src[d];
    }
}

// Stores `stream` and creates the cuBLAS handle used by Execute(), bound to that stream.
//
// Throws std::runtime_error if the handle cannot be created or bound. Without these checks a
// failed cublasCreate leaves `cublas_` unusable and every later GEMM fails silently.
// NOTE(review): the matching cublasDestroy is not visible in this file; confirm the handle is
// released by the destructor.
template<class T>
Reference<T>::Reference(cudaStream_t stream): stream_(stream)
{
    const cublasStatus_t created = cublasCreate(&cublas_);
    if (created != CUBLAS_STATUS_SUCCESS) {
        throw std::runtime_error("[Reference] cublasCreate failed with status "
                                 + std::to_string(static_cast<int>(created)));
    }

    const cublasStatus_t bound = cublasSetStream(cublas_, stream);
    if (bound != CUBLAS_STATUS_SUCCESS) {
        // The destructor does not run when a constructor throws, so release the handle here.
        cublasDestroy(cublas_);
        throw std::runtime_error("[Reference] cublasSetStream failed with status "
                                 + std::to_string(static_cast<int>(bound)));
    }
}

// Sizes every scratch buffer for the given problem shape, rebuilds the attention mask and
// records the shape for later Execute() calls.
//
// Prints the shape and the size of the QK buffer to stdout. The mask kernel is launched on
// `stream_` with one block per batch entry and no per-sequence lengths (both length arrays
// are passed as null), so every sequence uses the full `max_q_len` / `max_k_len`.
template<class T>
void Reference<T>::Reshape(size_t max_q_len,
                           size_t max_k_len,
                           size_t head_num,
                           size_t head_dim,
                           size_t kv_head_num,
                           size_t batch_size,
                           int    window_size)
{
    std::cout << max_q_len << " " << max_k_len << " " << head_num << " " << head_dim << " " << batch_size << "\n";

    // Element counts of the intermediate tensors.
    const size_t q_count    = batch_size * head_num * max_q_len * head_dim;  // Q and attention output
    const size_t mask_count = batch_size * max_q_len * max_k_len;            // one mask per batch entry
    const size_t qk_count   = batch_size * head_num * max_q_len * max_k_len;  // scores / probabilities
    const size_t kv_count   = batch_size * head_num * max_k_len * head_dim;   // K/V repeated to head_num

    q_.resize(q_count);
    mask_.resize(mask_count);

    std::cout << "size of QK buf: " << ((qk_count * sizeof(float)) / float(1 << 30)) << " GB\n";
    qk_.resize(qk_count);
    pr_.resize(qk_count);
    out_.resize(q_count);

    keys_.resize(kv_count);
    vals_.resize(kv_count);

    // Wait for work on the default stream before launching on stream_.
    cudaStreamSynchronize(0);

    constexpr int kMaskThreads = 512;
    auto          mask_ptr     = mask_.data().get();
    createCausalMasks<<<batch_size, kMaskThreads, 0, stream_>>>(
        mask_ptr, nullptr, nullptr, max_q_len, max_k_len, window_size);

    // Remember the shape for Execute().
    batch_size_  = batch_size;
    max_q_len_   = max_q_len;
    max_k_len_   = max_k_len;
    head_num_    = head_num;
    kv_head_num_ = kv_head_num;
    head_dim_    = head_dim;
    window_size_ = window_size;
}

// Runs the unfused reference attention for the shape configured by Reshape().
//
// Steps:
//   1. processQKV   : split `qkv` (+ optional `qkv_bias`, optional rotary embedding) into q_
//                     and append the new K/V to `k_cache` / `v_cache`.
//   2. RepeatKVKernel: expand the caches from kv_head_num_ to head_num_ heads (keys_/vals_).
//   3. GEMM         : qk_ = (1 / sqrt(head_dim_)) * Q.K^T, accumulated and stored as float.
//   4. invokeMaskedSoftmax : pr_ = softmax of qk_ under mask_ (with optional `sinks`).
//   5. GEMM         : out_ = P.V, laid out [B, H, Q, D].
//   6. invokeTransposeAttentionOutRemovePadding : out_ -> `output` as [B, Q, H, D].
//
// The device is synchronized before and after steps 1-2; there is no synchronization after
// step 6 in this function.
//
// Throws std::runtime_error when a CUDA / cuBLAS call or kernel launch reports an error, or
// when head_num_ is not a multiple of a non-zero kv_head_num_. Previously all status codes
// were ignored, so a failure silently produced a garbage reference result.
template<class T>
void Reference<T>::Execute(
    T* output, T* k_cache, T* v_cache, const T* qkv, const T* qkv_bias, const T* sinks, float rope_base, int rope_dim)
{
    const auto check_cuda = [](cudaError_t status, const char* what) {
        if (status != cudaSuccess) {
            throw std::runtime_error(std::string("[Reference] ") + what + ": " + cudaGetErrorString(status));
        }
    };
    const auto check_cublas = [](cublasStatus_t status, const char* what) {
        if (status != CUBLAS_STATUS_SUCCESS) {
            throw std::runtime_error(std::string("[Reference] ") + what + " failed with cuBLAS status "
                                     + std::to_string(static_cast<int>(status)));
        }
    };

    // RepeatKVKernel maps output head h to KV head h / (head_num_ / kv_head_num_). That index
    // only stays inside the cache when head_num_ is a multiple of kv_head_num_; a zero
    // kv_head_num_ would also divide by zero below.
    if (kv_head_num_ == 0 || head_num_ % kv_head_num_ != 0) {
        throw std::runtime_error("[Reference] head_num must be a multiple of a non-zero kv_head_num");
    }

    {
        int  threads = 128;
        dim3 blocks(max_q_len_, head_num_, batch_size_);
        check_cuda(cudaDeviceSynchronize(), "cudaDeviceSynchronize before processQKV");

        processQKV<<<blocks, threads, 0, stream_>>>(q_.data().get(),  //
                                                    k_cache,
                                                    v_cache,
                                                    qkv,
                                                    qkv_bias,
                                                    max_q_len_,
                                                    max_k_len_,
                                                    head_num_,
                                                    head_dim_,
                                                    kv_head_num_,
                                                    rope_base,
                                                    rope_dim);
        check_cuda(cudaGetLastError(), "processQKV launch");

        // Same head / batch extents, but one block per cached key position.
        blocks.x = max_k_len_;
        RepeatKVKernel<<<blocks, threads, 0, stream_>>>(keys_.data().get(),
                                                        vals_.data().get(),
                                                        k_cache,
                                                        v_cache,
                                                        head_num_,
                                                        max_k_len_,
                                                        head_dim_,
                                                        kv_head_num_,
                                                        head_num_ / kv_head_num_);
        check_cuda(cudaGetLastError(), "RepeatKVKernel launch");

        // Also surfaces asynchronous execution errors of the two kernels above.
        check_cuda(cudaDeviceSynchronize(), "cudaDeviceSynchronize after RepeatKVKernel");
    }

    const cudaDataType data_type = std::is_same_v<T, half> ? CUDA_R_16F : CUDA_R_16BF;

    // Scores: one (max_k_len_ x max_q_len_) column-major matrix per (batch, head), scaled by
    // 1 / sqrt(head_dim_) and stored as float.
    float alpha = 1.f / sqrtf((float)head_dim_);
    float beta  = 0.f;
    check_cublas(cublasGemmStridedBatchedEx(cublas_,
                                            CUBLAS_OP_T,              // trans A
                                            CUBLAS_OP_N,              // trans B
                                            max_k_len_,               // m
                                            max_q_len_,               // n
                                            head_dim_,                // k
                                            &alpha,                   // alpha
                                            keys_.data().get(),       // A
                                            data_type,                // A type
                                            head_dim_,                // lda
                                            max_k_len_ * head_dim_,   // strideA
                                            q_.data().get(),          // B
                                            data_type,                // B type
                                            head_dim_,                // ldb
                                            max_q_len_ * head_dim_,   // stride B
                                            &beta,                    // beta
                                            qk_.data().get(),         // C
                                            CUDA_R_32F,               // C type
                                            max_k_len_,               // ldc
                                            max_q_len_ * max_k_len_,  // stride C
                                            batch_size_ * head_num_,  // batch count
                                            CUBLAS_COMPUTE_32F,       // compute type
                                            CUBLAS_GEMM_DEFAULT),
                 "QK GEMM");

    MaskedSoftmaxParam<T> params{};
    params.attention_score = pr_.data().get();
    params.qk              = qk_.data().get();
    params.attention_mask  = mask_.data().get();
    params.batch_size      = batch_size_;
    params.q_length        = max_q_len_;
    params.k_length        = max_k_len_;
    params.num_heads       = head_num_;
    params.sinks           = sinks;
    invokeMaskedSoftmax(params, stream_);
    check_cuda(cudaGetLastError(), "invokeMaskedSoftmax");

    // Output: probabilities times values, unscaled.
    alpha = 1.f;
    check_cublas(cublasGemmStridedBatchedEx(cublas_,
                                            CUBLAS_OP_N,              // trans A
                                            CUBLAS_OP_N,              // trans B
                                            head_dim_,                // m
                                            max_q_len_,               // n
                                            max_k_len_,               // k
                                            &alpha,                   // alpha
                                            vals_.data().get(),       // A
                                            data_type,                // A type
                                            head_dim_,                // lda
                                            max_k_len_ * head_dim_,   // strideA
                                            pr_.data().get(),         // B
                                            data_type,                // B type
                                            max_k_len_,               // ldb
                                            max_q_len_ * max_k_len_,  // stride B
                                            &beta,                    // beta
                                            out_.data().get(),        // C [b, h, q, d]
                                            data_type,                // C type
                                            head_dim_,                // ldc
                                            max_q_len_ * head_dim_,   // stride C
                                            batch_size_ * head_num_,  // batch count
                                            CUBLAS_COMPUTE_32F,       // compute type
                                            CUBLAS_GEMM_DEFAULT),
                 "PV GEMM");

    // [B, H, Q, D] -> [B, Q, H, D]
    invokeTransposeAttentionOutRemovePadding(out_.data().get(),
                                             output,
                                             batch_size_ * max_q_len_,
                                             batch_size_,
                                             max_q_len_,
                                             head_num_,
                                             head_dim_,
                                             nullptr,
                                             nullptr,
                                             0,
                                             stream_);
    check_cuda(cudaGetLastError(), "invokeTransposeAttentionOutRemovePadding");
}

// Explicit instantiation of the reference implementation for half precision.
template class Reference<half>;

// The bfloat16 instantiation is only compiled when the build defines ENABLE_BF16.
#if ENABLE_BF16
template class Reference<nv_bfloat16>;
#endif

}  // namespace turbomind
